import pandas as pd
from transformers import pipeline, logging
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score, precision_recall_fscore_support, classification_report
from datasets import load_dataset, DatasetDict, Dataset, disable_progress_bars
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import numpy as np
import warnings

# Quiet the console: turn off `datasets` progress bars, restrict transformers
# logging to errors, and ignore User/Future warnings.
disable_progress_bars()
logging.set_verbosity_error()
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# pandas options: opt in to the future no-silent-downcasting behaviour and
# switch off the chained-assignment check.
pd.set_option('future.no_silent_downcasting', True)
pd.set_option('mode.chained_assignment', None)

###########################
## Set device from CLI args
###########################
import os
import argparse
def get_device():
    """Resolve the name of the compute device requested for this run.

    Lookup order: the ``--device`` command-line flag, then the ``DEVICE``
    environment variable, then ``"cuda"`` as the built-in default. Blank
    values are treated as "not set", and the result is stripped and
    lower-cased so that e.g. ``--device CUDA`` is accepted.

    Returns:
        str: The requested device name. It is not validated here; the caller
        decides which names are acceptable.
    """
    # Allow --device on the command line
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    # parse_known_args ignores unknown args so the script still works normally
    args, _ = parser.parse_known_args()
    # Priority: CLI argument > DEVICE environment variable > default=cuda
    for candidate in (args.device, os.getenv("DEVICE")):
        if candidate and candidate.strip():
            return candidate.strip().lower()
    return "cuda"
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")
# PyTorch setup
import torch

# Availability probe per supported device name. The probes are wrapped in
# lambdas so only the one for the requested device is ever evaluated.
_availability_checks = {
    "cpu": lambda: True,
    "mps": lambda: torch.backends.mps.is_available(),
    "cuda": lambda: torch.cuda.is_available(),
}
if DEVICE not in _availability_checks:
    raise ValueError(f"Unknown device {DEVICE}")
# Use the requested backend when it is available, otherwise fall back to CPU.
device = torch.device(DEVICE if _availability_checks[DEVICE]() else "cpu")
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################
# Labeled tweets used as the few-shot training pool. Take an explicit copy of
# the two needed columns so the assignments below write to an independent
# frame rather than to a selection of the frame returned by read_csv.
df = pd.read_csv('./data/covid_tweets_labeled.csv')
df = df[['text', 'non_comp']].copy()
# Every tweet (premise) is paired with the same hypothesis sentence.
df['hypothesis'] = 'The author of this tweet does not believe COVID is dangerous.'
df = df.rename(columns={'text': 'premise', 'non_comp': 'entailment'})
# Swap the 0/1 coding of the label column. The result is assigned back
# instead of calling replace(..., inplace=True) on df['entailment']: an
# in-place method on a column selection is a chained assignment and leaves
# `df` unchanged under pandas Copy-on-Write.
df['entailment'] = df['entailment'].replace({0: 1, 1: 0})

# Tweets that receive the model's label in the final section of the script.
covid = pd.read_csv('./data/covid_classified_tweets.csv')

training_directory = 'fewshot'
modname = "mlburnham/Political_DEBATE_DeBERTa_large_v1.1"
# instantiate model
#model = AutoModelForSequenceClassification.from_pretrained(modname, num_labels = 2, ignore_mismatched_sizes=True)

#################
## Few-shot Train
#################
# Half precision is tied to the resolved torch device: TrainingArguments
# rejects fp16 / fp16_full_eval when training would run on CPU, so enabling
# them unconditionally breaks the cpu/mps code paths set up above.
use_fp16 = device.type == "cuda"

training_args = TrainingArguments(
    output_dir=training_directory,
    logging_dir=f'{training_directory}/logs',
    lr_scheduler_type="linear",
    group_by_length=False,
    learning_rate=9e-6,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=5,
    warmup_ratio=0.06,
    weight_decay=0.01,
    fp16=use_fp16,
    fp16_full_eval=use_fp16,
    # Keep the Trainer on CPU when the resolved device is CPU; otherwise it
    # would pick an accelerator on its own.
    use_cpu=device.type == "cpu",
    # No evaluation passes or checkpoints are scheduled during training.
    eval_strategy="no",
    seed=1,
    save_strategy="no",
    dataloader_num_workers=0,
    disable_tqdm=True,
)

# Tokenizer that belongs to the pretrained checkpoint named by `modname`.
tokenizer = AutoTokenizer.from_pretrained(modname)


def model_init():
    """Load and return the `modname` checkpoint as a sequence classifier.

    The model is requested with two labels, and size mismatches between the
    checkpoint and the requested configuration are ignored while loading.
    """
    load_options = {"num_labels": 2, "ignore_mismatched_sizes": True}
    return AutoModelForSequenceClassification.from_pretrained(modname, **load_options)

def tokenize_function(docs):
    """Tokenize the premise/hypothesis pairs in `docs`.

    `docs` is indexed by the keys 'premise' and 'hypothesis'. The pairs are
    padded with the 'max_length' strategy and truncated, and the module-level
    tokenizer's output is returned unchanged.
    """
    premises = docs['premise']
    hypotheses = docs['hypothesis']
    return tokenizer(
        premises,
        hypotheses,
        padding='max_length',
        truncation=True,
    )
    
# Few-shot training sample: 25 labeled rows drawn with a fixed seed
train = df.sample(25, random_state=1)
# Create validation set with remaining instances
val = df[~df.index.isin(train.index)]

# Create a DatasetDict with train and validation splits
ds = DatasetDict({
    'train': Dataset.from_pandas(train, preserve_index=False),
    'validation': Dataset.from_pandas(val, preserve_index=False),
})
# Tokenize the dataset
dstok = ds.map(tokenize_function, batched=True)
# Rename 'entailment' column to 'label'
dstok = dstok.rename_columns({'entailment': 'label'})
# Define label mapping
id2label = {0: "entailment", 1: "not_entailment"}


def compute_metrics(eval_pred):
    """Score the predictions of one evaluation pass.

    Replaces a lambda that referenced `compute_metrics_standard` and `model`,
    neither of which is defined in this script, so any evaluation call would
    have raised NameError.

    Args:
        eval_pred: object exposing `predictions` (class logits, or a tuple
            whose first item is the logits) and `label_ids`.

    Returns:
        dict: accuracy, balanced accuracy, macro-averaged F1 and Matthews
        correlation of the argmax predictions against the labels.
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    predicted = np.argmax(logits, axis=-1)
    labels = eval_pred.label_ids
    return {
        'accuracy': accuracy_score(labels, predicted),
        'balanced_accuracy': balanced_accuracy_score(labels, predicted),
        'f1_macro': f1_score(labels, predicted, average='macro'),
        'mcc': matthews_corrcoef(labels, predicted),
    }


# Initialize the Trainer
trainer = Trainer(
    model_init=model_init,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dstok['train'],
    eval_dataset=dstok['validation'],
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

###################
## Label All Tweets
###################
# Wrap the fine-tuned model in a zero-shot (NLI) classification pipeline.
# It runs on the torch device resolved at start-up rather than a hard-coded
# 'cuda', so the script also works when that device is cpu or mps. The
# tokenizer is the module-level one that was handed to the Trainer.
pipe = pipeline(task='zero-shot-classification', model=trainer.model, tokenizer=tokenizer,
                batch_size=16, device=device,
                max_length=512, truncation=True)

# The single candidate label plus the '{}.' template rebuilds the hypothesis
# sentence used for training (including its trailing period).
res = pipe(list(covid['text'].str.lower()), ['The author of this tweet does not believe COVID is dangerous'], hypothesis_template='{}.', multi_label=True)
# Keep the score of the only candidate label for each tweet.
labs = [result['scores'][0] for result in res]
covid['debate'] = labs
# Round the score to a hard 0/1 label.
covid['debate'] = covid['debate'].round()
# Write the labels back to the same file the tweets were read from.
covid.to_csv('./data/covid_classified_tweets.csv', index=False)